iT邦幫忙

2022 iThome Ironman Contest

DAY 1
Self-Challenge Group

resnet, resnext practice series, Part 1

Cats & Dogs dataset: ResNet50 practice with PyTorch


#Day1
Prepare 1. Download the archives (download locations below)
ResNet-50 model source: https://www.kaggle.com/datasets/pytorch/resnet50?datasetId=6979
Cats v/s Dogs Resnet50 code source: https://www.kaggle.com/code/skbadhsm/cats-v-s-dogs-resnet50/notebook
Dogs vs. Cats dataset: https://www.kaggle.com/competitions/dogs-vs-cats

##Part 1. Unzip the archives in code (or unzip them beforehand and use the files directly)

### Unzipping Dataset
import zipfile

with zipfile.ZipFile("../input/dogs-vs-cats/train.zip","r") as z:
    z.extractall(".")
    
with zipfile.ZipFile("../input/dogs-vs-cats/test1.zip","r") as z:
    z.extractall(".")
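
The option mentioned in the heading, unzipping ahead of time, can also be handled in code by skipping extraction when the target folder is already there. Here is a minimal sketch of that idea. The helper name `extract_if_missing` and its parameters are made up for illustration, and it assumes each archive unpacks into a folder with a known name (for example `train.zip` into `train/`).

```python
import os
import zipfile


def extract_if_missing(zip_path, expected_dir, dest="."):
    """Extract zip_path into dest unless dest/expected_dir already exists.

    Hypothetical helper for illustration; returns True if extraction ran.
    """
    target = os.path.join(dest, expected_dir)
    if os.path.isdir(target):
        return False  # already unzipped, nothing to do
    with zipfile.ZipFile(zip_path, "r") as z:
        z.extractall(dest)
    return True
```

In this notebook it would be called as `extract_if_missing("../input/dogs-vs-cats/train.zip", "train")`, so re-running the cell does not unpack 25,000 images again.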

##Part 2. Import packages

import os
import cv2
import time
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.utils import make_grid
from torchvision.models import resnet50 #**

from sklearn.model_selection import train_test_split

from PIL import Image

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

#***
from torchvision.models import ResNet50_Weights
from torch.optim import Optimizer as optimizer

DIR_TRAIN = "/kaggle/working/train/"
DIR_TEST = "/kaggle/working/test1"

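Changes from the original notebook (at the #*** markers):

1. resnet50
Old: model = resnet50(pretrained = True) (# deprecated)
New: model = resnet50(weights=ResNet50_Weights.DEFAULT)
details: https://pytorch.org/vision/stable/models.html
2. optimizer: in the original code the name `optimizer` sometimes could not be resolved on my side, so I looked up its package and imported it explicitly. Note that `Optimizer` is the base class of the optimizers in `torch.optim`; a concrete optimizer such as `torch.optim.Adam` still has to be instantiated before training.
from torch.optim import Optimizer as optimizer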

##Part 3. Load the images (training set and test set)

### Checking Data Format
imgs = os.listdir(DIR_TRAIN) 
test_imgs = os.listdir(DIR_TEST)

print(imgs[:5])
print(test_imgs[:5])

output:
['cat.0.jpg', 'cat.1.jpg', 'cat.10.jpg', 'cat.100.jpg', 'cat.1000.jpg']
['1.jpg', '10.jpg', '100.jpg', '1000.jpg', '10000.jpg']
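
The listing above is not in numeric order ('10.jpg' comes before any '2.jpg'), and `os.listdir` does not guarantee any particular order anyway. If the numeric order matters later, for example when matching predictions to image ids, the names can be sorted by their embedded number. A minimal sketch follows; the helper name `numeric_id` and the sample list are made up for illustration.

```python
def numeric_id(filename):
    """Return the integer id embedded in names like '10.jpg' or 'cat.10.jpg'."""
    return int(filename.split(".")[-2])


# Made-up sample in the same naming style as the test folder.
test_names = ["1.jpg", "10.jpg", "100.jpg", "1000.jpg", "2.jpg"]
sorted_names = sorted(test_names, key=numeric_id)
print(sorted_names)
# ['1.jpg', '2.jpg', '10.jpg', '100.jpg', '1000.jpg']
```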

##Part 4. Separate cats and dogs

### Class Distribution
dogs_list = [img for img in imgs if img.split(".")[0] == "dog"]
cats_list = [img for img in imgs if img.split(".")[0] == "cat"]

print("No of Dogs Images: ",len(dogs_list))
print("No of Cats Images: ",len(cats_list))

class_to_int = {"dog" : 0, "cat" : 1}
int_to_class = {0 : "dog", 1 : "cat"}

output:
No of Dogs Images: 12500
No of Cats Images: 12500
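
The same count can be done in one pass, which also makes it easy to spot files that belong to neither class. This is only a sketch: `count_classes` is a hypothetical helper, and the sample list (including the stray `notes.txt`) is invented.

```python
from collections import Counter


def count_classes(filenames, known=("dog", "cat")):
    """Count files per filename prefix and list prefixes that are not known classes."""
    counts = Counter(name.split(".")[0] for name in filenames)
    unknown = sorted(set(counts) - set(known))
    return counts, unknown


sample = ["cat.0.jpg", "dog.0.jpg", "dog.1.jpg", "notes.txt"]
counts, unknown = count_classes(sample)
print(counts["dog"], counts["cat"], unknown)  # 2 1 ['notes']
```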

##Part 5. Transform images - ToTensor and other augmentations
sources:
1. https://pytorch.org/vision/stable/generated/torchvision.transforms.Compose.html
2. https://pytorch.org/vision/main/generated/torchvision.transforms.ToTensor.html

def get_train_transform():
    return T.Compose([
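        T.RandomHorizontalFlip(p=0.5),  # random horizontal flip
        T.RandomRotation(15),           # random rotation within +/-15 degrees
        T.RandomCrop(204),              # random 204x204 crop (not a rescale)
        T.ToTensor(),                   # convert a PIL image or numpy.ndarray to a torch tensor
        T.Normalize((0, 0, 0),(1, 1, 1))# per-channel (R,G,B) mean/std normalization; mean 0 and std 1 leave the values unchanged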
    ])
    
def get_val_transform():
    return T.Compose([
        T.ToTensor(),
        T.Normalize((0, 0, 0),(1, 1, 1))
    ])

##Part 6. Define the Dataset class for retrieving images and labels

### Dataset Class - for retrieving images and labels
class CatDogDataset(Dataset):
    
    def __init__(self, imgs, class_to_int, mode = "train", transforms = None):
        
        super().__init__()
        self.imgs = imgs
        self.class_to_int = class_to_int
        self.mode = mode
        self.transforms = transforms
        
    def __getitem__(self, idx):
        
        image_name = self.imgs[idx]
        
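        ### Reading and resizing image
        ### (alternative OpenCV pipeline, kept for reference)
        #img = cv2.imread(DIR_TRAIN + image_name, cv2.IMREAD_COLOR)
        #img = cv2.resize(img, (224,224))
        #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
        #img /= 255.
        ### Test images live in DIR_TEST; train and val images live in DIR_TRAIN
        img_dir = DIR_TEST if self.mode == "test" else DIR_TRAIN
        img = Image.open(os.path.join(img_dir, image_name)).convert("RGB")
        img = img.resize((224, 224))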
        
        if self.mode == "train" or self.mode == "val":
        
            ### Preparing class label
            label = self.class_to_int[image_name.split(".")[0]]
            label = torch.tensor(label, dtype = torch.float32)

            ### Apply Transforms on image
            img = self.transforms(img)

            return img, label
        
        elif self.mode == "test":
            
            ### Apply Transforms on image
            img = self.transforms(img)

            return img
            
        
    def __len__(self):
        return len(self.imgs)
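
The class above returns the label as a `float32` tensor. The post does not say which loss will be used, so the following is only a sketch under the assumption of a single-logit model trained with `nn.BCEWithLogitsLoss`, which expects float targets with the same shape as the logits. The helper `encode_label` is hypothetical and mirrors the label lines in `__getitem__`.

```python
import torch
import torch.nn as nn

class_to_int = {"dog": 0, "cat": 1}


def encode_label(image_name):
    """Map 'cat.0.jpg' -> tensor(1.) and 'dog.7.jpg' -> tensor(0.)."""
    return torch.tensor(class_to_int[image_name.split(".")[0]], dtype=torch.float32)


labels = torch.stack([encode_label(n) for n in ["cat.0.jpg", "dog.7.jpg"]])
logits = torch.zeros(2, 1)  # stand-in for model outputs, shape (batch, 1)

# Targets need the same shape as the logits, hence unsqueeze(1).
loss = nn.BCEWithLogitsLoss()(logits, labels.unsqueeze(1))
print(labels, loss.item())  # a logit of 0 means probability 0.5, so the loss is ln(2)
```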
    

##Part 7. Split into training and validation sets

### Splitting data into train and val sets
train_imgs, val_imgs = train_test_split(imgs, test_size = 0.25)
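
The split above is random on every run and does not force the cat/dog ratio to be the same in both parts. A possible variant, sketched here on made-up filenames, passes `stratify` to keep the classes balanced and `random_state` to make the split repeatable. The value 42 is arbitrary.

```python
from collections import Counter
from sklearn.model_selection import train_test_split

# Made-up filenames in the same style as the training folder.
imgs = [f"cat.{i}.jpg" for i in range(20)] + [f"dog.{i}.jpg" for i in range(20)]
labels = [name.split(".")[0] for name in imgs]

train_imgs, val_imgs = train_test_split(
    imgs, test_size=0.25, stratify=labels, random_state=42
)

val_counts = Counter(name.split(".")[0] for name in val_imgs)
print(len(train_imgs), len(val_imgs))  # 30 10
print(val_counts["cat"], val_counts["dog"])  # 5 5
```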

##Part 8. Set up the data to use

### Dataloaders

train_dataset = CatDogDataset(train_imgs, class_to_int, mode = "train", transforms = get_train_transform())
val_dataset = CatDogDataset(val_imgs, class_to_int, mode = "val", transforms = get_val_transform())
test_dataset = CatDogDataset(test_imgs, class_to_int, mode = "test", transforms = get_val_transform())

train_data_loader = DataLoader(
    dataset = train_dataset,
    num_workers = 4,
    batch_size = 16,
    shuffle = True
)

val_data_loader = DataLoader(
    dataset = val_dataset,
    num_workers = 4,
    batch_size = 16,
    shuffle = False  # evaluation does not need shuffling
)

test_data_loader = DataLoader(
    dataset = test_dataset,
    num_workers = 4,
    batch_size = 16,
    shuffle = False  # keep predictions in the same order as test_imgs
)
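***PROBLEM
1. RuntimeError mentioning `if __name__ == '__main__':` and `freeze_support()`
This error appears when the DataLoader worker processes are started by spawning a fresh Python process (the default on Windows and macOS) rather than by fork(). Each worker re-imports the main script, so code that starts workers at import time would try to start them again. Putting the code that loops over the loaders under `if __name__ == '__main__':` usually fixes it. Alternatively, comment out every `num_workers = 4,` line so the data is loaded in the main process.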
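
As a script (outside a notebook), the guard described above could look like the sketch below. It uses fake tensors instead of the cat/dog images so it stays self-contained. The function name `first_batch_shape` is made up, and `num_workers=2` is just an example value.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def first_batch_shape(num_workers):
    """Build a loader over fake data and return the shape of its first image batch."""
    fake_images = torch.zeros(32, 3, 8, 8)
    fake_labels = torch.zeros(32)
    loader = DataLoader(
        TensorDataset(fake_images, fake_labels),
        batch_size=16,
        num_workers=num_workers,
    )
    images, labels = next(iter(loader))
    return tuple(images.shape)


if __name__ == "__main__":
    # Only the process that launched the script runs this block. Spawned
    # workers re-import the file under a different __name__ and skip it.
    print(first_batch_shape(num_workers=2))
```

With `num_workers=0` the guard is not needed, because no worker processes are created.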

##Part 9. Visualize randomly selected images

# if __name__ == '__main__':
### Visualize Random Images from Train set
for images, labels in train_data_loader:
    
    fig, ax = plt.subplots(figsize = (10, 10))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(make_grid(images, 4).permute(1,2,0))
    plt.show()
    break

output:
https://ithelp.ithome.com.tw/upload/images/20220908/20151711TiU24HD4Ek.png

